def solve(pattern, s):
    """Return True if *s* follows *pattern* under a one-to-one mapping.

    Each character of ``pattern`` must correspond to exactly one
    whitespace-separated word of ``s``, and each word to exactly one
    character (a bijection). The lengths must also match.

    Args:
        pattern: Sequence of pattern symbols (typically a str).
        s: Text split on whitespace into words.

    Returns:
        True when the character<->word correspondence is consistent in both
        directions, otherwise False.

    >>> solve("abba", "dog cat cat dog")
    True
    >>> solve("aab", "x y y")
    False
    """
    words = s.split()
    if len(pattern) != len(words):
        return False

    char_to_word = {}
    word_to_char = {}
    for ch, word in zip(pattern, words):
        # setdefault records the first pairing seen. Any later pairing that
        # disagrees with it (in either direction) breaks the bijection.
        if char_to_word.setdefault(ch, word) != word:
            return False
        if word_to_char.setdefault(word, ch) != ch:
            return False

    return True


if __name__ == "__main__":
    # Demo run: a pattern whose letters mirror the word sequence.
    pattern, s = "abba", "dog cat cat dog"
    print(solve(pattern, s))
